{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 中文CLIP图片搜索演示"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:29:44.111854Z",
     "start_time": "2021-04-10T01:29:39.825817Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "# Default to GPU 6 for this demo, but let a CUDA_VISIBLE_DEVICES value that is\n",
    "# already exported in the environment take precedence. Must run before torch is imported.\n",
    "os.environ.setdefault('CUDA_VISIBLE_DEVICES', '6')\n",
    "import sys\n",
    "sys.path.insert(0, '..')\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "from clip.clip import build_model\n",
    "from clip.simple_tokenizer import SimpleCharTokenizer\n",
    "from clip.data import tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:29:44.118443Z",
     "start_time": "2021-04-10T01:29:44.114179Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device: cuda\n"
     ]
    }
   ],
   "source": [
    "# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "print(f\"device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:30:10.252260Z",
     "start_time": "2021-04-10T01:29:44.121967Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "embed_dim: 512\n",
      "vocab_size: 8000\n",
      "text_projection shape: torch.Size([512, 512])\n"
     ]
    }
   ],
   "source": [
    "# Deserialize the checkpoint onto the CPU first: without map_location, torch.load\n",
    "# restores tensors to the device they were saved from, which fails on CPU-only\n",
    "# machines or when that GPU index is not visible in this process.\n",
    "state_dict = torch.load('../data/clip_zh.mse.pt', map_location='cpu')\n",
    "\n",
    "# Build the model from the loaded weights, then move it to the selected device.\n",
    "model = build_model(state_dict).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:30:14.365772Z",
     "start_time": "2021-04-10T01:30:10.257818Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Photos loaded: 1981161\n"
     ]
    }
   ],
   "source": [
    "# Photo IDs: read the CSV and keep its 'photo_id' column as a plain list\n",
    "photo_ids = list(pd.read_csv(\"../data/unsplash-dataset/photo_ids.csv\")['photo_id'])\n",
    "\n",
    "# Feature vectors stored alongside the photo IDs\n",
    "photo_features = np.load(\"../data/unsplash-dataset/features.npy\")\n",
    "\n",
    "# Report how many photo IDs were loaded\n",
    "print(f\"Photos loaded: {len(photo_ids)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:30:28.423634Z",
     "start_time": "2021-04-10T01:30:28.410329Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading vocab from ../data/zh_vocab.txt\n"
     ]
    }
   ],
   "source": [
    "# Tokenizer used for the text queries, built from the vocabulary file\n",
    "zh_tokenizer = SimpleCharTokenizer(\"../data/zh_vocab.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:30:34.307504Z",
     "start_time": "2021-04-10T01:30:34.303558Z"
    }
   },
   "outputs": [],
   "source": [
    "def encode_search_query(search_query):\n",
    "    \"\"\"Embed a text query with the model's text encoder.\n",
    "\n",
    "    Tokenizes `search_query` with the notebook-level `zh_tokenizer`, runs it\n",
    "    through `model.encode_text` on `device`, divides the result by its norm\n",
    "    along the last dimension, and returns it as a NumPy array on the CPU.\n",
    "    \"\"\"\n",
    "    tokens = tokenize(zh_tokenizer, search_query).to(device)\n",
    "\n",
    "    # Inference only: no autograd bookkeeping is needed\n",
    "    with torch.no_grad():\n",
    "        query_embedding = model.encode_text(tokens)\n",
    "        # Normalize in place\n",
    "        query_embedding /= query_embedding.norm(dim=-1, keepdim=True)\n",
    "\n",
    "    return query_embedding.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:30:35.353420Z",
     "start_time": "2021-04-10T01:30:35.349550Z"
    }
   },
   "outputs": [],
   "source": [
    "def find_best_matches(text_features, photo_features, photo_ids, results_count=3):\n",
    "    \"\"\"Return the IDs of the photos most similar to a query vector.\n",
    "\n",
    "    Args:\n",
    "        text_features: query features; `photo_features @ text_features.T` must\n",
    "            yield a single column (it is squeezed along axis 1).\n",
    "        photo_features: array with one feature row per photo.\n",
    "        photo_ids: sequence of IDs, indexed in step with the rows of `photo_features`.\n",
    "        results_count: number of best matches to return.\n",
    "\n",
    "    Returns:\n",
    "        A list of up to `results_count` photo IDs, best match first.\n",
    "    \"\"\"\n",
    "    # Dot product of every photo with the query; this equals the cosine similarity\n",
    "    # when both sides are unit-normalized (assumed for features.npy - TODO confirm).\n",
    "    similarities = (photo_features @ text_features.T).squeeze(1)\n",
    "    negated = -similarities\n",
    "\n",
    "    if 0 < results_count < negated.shape[0]:\n",
    "        # Select the top-k without sorting every photo, then order only those\n",
    "        # k candidates by score.\n",
    "        top_k = np.argpartition(negated, results_count - 1)[:results_count]\n",
    "        best_photo_idx = top_k[negated[top_k].argsort()]\n",
    "    else:\n",
    "        # Everything was requested (or a non-positive count): keep the full sort\n",
    "        # and the original slicing semantics.\n",
    "        best_photo_idx = negated.argsort()[:results_count]\n",
    "\n",
    "    # Map row indices back to photo IDs\n",
    "    return [photo_ids[i] for i in best_photo_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:30:36.202868Z",
     "start_time": "2021-04-10T01:30:36.199084Z"
    }
   },
   "outputs": [],
   "source": [
    "# Import everything from the public IPython.display API, including display(),\n",
    "# so the cell does not rely on names injected into the notebook namespace.\n",
    "from IPython.display import HTML, Image, display\n",
    "\n",
    "def display_photo(photo_id):\n",
    "    \"\"\"Render one Unsplash photo followed by its attribution link.\n",
    "\n",
    "    Args:\n",
    "        photo_id: Unsplash photo ID, interpolated into the image and page URLs.\n",
    "    \"\"\"\n",
    "    # Get the URL of the photo resized to have a width of 320px\n",
    "    photo_image_url = f\"https://unsplash.com/photos/{photo_id}/download?w=320\"\n",
    "\n",
    "    # Display the photo (referenced by URL)\n",
    "    display(Image(url=photo_image_url))\n",
    "\n",
    "    # Display the attribution text\n",
    "    display(HTML(f'Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/{photo_id}\">Unsplash</a> '))\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 以文搜图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:30:40.408440Z",
     "start_time": "2021-04-10T01:30:40.404513Z"
    }
   },
   "outputs": [],
   "source": [
    "def search_unslash(search_query, photo_features, photo_ids, results_count=3):\n",
    "    \"\"\"Show the photos that best match a text query.\n",
    "\n",
    "    The query is embedded with `encode_search_query`, ranked against\n",
    "    `photo_features` by `find_best_matches`, and each of the top\n",
    "    `results_count` photos is rendered with `display_photo`.\n",
    "    \"\"\"\n",
    "    query_vector = encode_search_query(search_query)\n",
    "    top_ids = find_best_matches(query_vector, photo_features, photo_ids, results_count)\n",
    "    for matched_id in top_ids:\n",
    "        display_photo(matched_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:30:47.937215Z",
     "start_time": "2021-04-10T01:30:40.938290Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/jOilwGNg-YE/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/jOilwGNg-YE\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/BxHe-dKdxmE/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/BxHe-dKdxmE\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/cBAcMcwQ1h0/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/cBAcMcwQ1h0\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Query: 'two dogs playing in the snow'\n",
    "search_query = \"两只狗在雪地里玩耍\"\n",
    "\n",
    "search_unslash(search_query, photo_features, photo_ids, results_count=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:30:58.917054Z",
     "start_time": "2021-04-10T01:30:51.809258Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/Z_ZyLRsjIBI/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/Z_ZyLRsjIBI\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/mKEWMUmFgHU/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/mKEWMUmFgHU\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/HWUTM_DW6nM/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/HWUTM_DW6nM\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Query: 'love written on a wall'\n",
    "search_query = \"写在墙上的爱\"\n",
    "\n",
    "search_unslash(search_query, photo_features, photo_ids, results_count=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:31:05.923309Z",
     "start_time": "2021-04-10T01:30:58.919637Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/UMlT0bviaek/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/UMlT0bviaek\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/oFBw2y1T5wY/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/oFBw2y1T5wY\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/pMnw5BSZYsA/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/pMnw5BSZYsA\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Query: 'the feeling when your code finally works'\n",
    "search_query = \"当你的代码终于成功时的感受\"\n",
    "\n",
    "search_unslash(search_query, photo_features, photo_ids, results_count=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:31:12.929963Z",
     "start_time": "2021-04-10T01:31:05.925648Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/beYNO-kn4-Y/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/beYNO-kn4-Y\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/wAThchhhCRU/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/wAThchhhCRU\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/OAzRalQb0F4/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/OAzRalQb0F4\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Query: 'Sydney Opera House and the Harbour Bridge at night'\n",
    "# NOTE(review): the Sydney Harbour Bridge is usually written as 悉尼海港大桥;\n",
    "# the query text is left unchanged so results match the recorded outputs.\n",
    "search_query = \"晚上的悉尼歌剧院和哈勃桥\"\n",
    "\n",
    "search_unslash(search_query, photo_features, photo_ids, results_count=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:31:20.093005Z",
     "start_time": "2021-04-10T01:31:12.932150Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/Qbk8icsUeVo/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/Qbk8icsUeVo\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/nUZDeq8IsiI/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/nUZDeq8IsiI\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/m7dwbqqmrJA/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/m7dwbqqmrJA\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Query: 'Sydney Opera House and a blue sky'\n",
    "search_query = \"悉尼歌剧院和蓝蓝的天空\"\n",
    "\n",
    "search_unslash(search_query, photo_features, photo_ids, results_count=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 以图搜图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:31:47.920977Z",
     "start_time": "2021-04-10T01:31:47.914426Z"
    }
   },
   "outputs": [],
   "source": [
    "def search_by_photo(query_photo_id, photo_ids, query_num=3):\n",
    "    \"\"\"Display the photos most similar to an already-indexed photo.\n",
    "\n",
    "    Args:\n",
    "        query_photo_id: ID of the query photo; must be present in `photo_ids`\n",
    "            (`photo_ids.index` raises ValueError otherwise).\n",
    "        photo_ids: list of photo IDs, indexed in step with the notebook-level\n",
    "            `photo_features` array.\n",
    "        query_num: maximum number of similar photos to display.\n",
    "    \"\"\"\n",
    "    # Find the feature vector for the specified photo ID and add a leading axis\n",
    "    query_photo_index = photo_ids.index(query_photo_id)\n",
    "    query_photo_features = np.expand_dims(photo_features[query_photo_index], axis=0)\n",
    "    print(query_photo_features.shape)\n",
    "\n",
    "    # Ask for one extra match because the query photo is expected to rank among\n",
    "    # its own best matches and is filtered out below.\n",
    "    best_photo_ids = find_best_matches(query_photo_features, photo_features, photo_ids, query_num + 1)\n",
    "\n",
    "    # Drop the query photo, then cap at query_num: if the query photo did not\n",
    "    # make the top query_num + 1, the extra match must not be displayed.\n",
    "    similar_photo_ids = [photo_id for photo_id in best_photo_ids if photo_id != query_photo_id]\n",
    "    for photo_id in similar_photo_ids[:query_num]:\n",
    "        display_photo(photo_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:31:48.923019Z",
     "start_time": "2021-04-10T01:31:48.917810Z"
    }
   },
   "outputs": [],
   "source": [
    "def show_img_by_id(photo_id):\n",
    "    \"\"\"Display the Unsplash photo `photo_id` (requested with w=320), without attribution.\"\"\"\n",
    "    image_url = f\"https://unsplash.com/photos/{photo_id}/download?w=320\"\n",
    "    display(Image(url=image_url))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:31:49.704629Z",
     "start_time": "2021-04-10T01:31:49.699983Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/QP-l1uE19iI/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Preview the photo that serves as the query for the image-to-image search\n",
    "show_img_by_id(photo_id='QP-l1uE19iI')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:32:01.656192Z",
     "start_time": "2021-04-10T01:31:50.900904Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 512)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/UnxClUska98/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/UnxClUska98\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/2gwhorA9tic/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/2gwhorA9tic\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/46sXIKactww/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/46sXIKactww\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Show up to 3 photos similar to the query photo\n",
    "search_by_photo('QP-l1uE19iI', photo_ids, query_num=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 图文混合搜图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:32:01.669730Z",
     "start_time": "2021-04-10T01:32:01.661319Z"
    }
   },
   "outputs": [],
   "source": [
    "def search_by_text_and_photo(query_text, query_photo_id, photo_weight=0.2):\n",
    "    \"\"\"Search with a text query blended with an existing photo's features.\n",
    "\n",
    "    Displays, in order: the best text-only match, the query photo, and the best\n",
    "    match for the combined query. Reads the notebook-level `photo_features` and\n",
    "    `photo_ids`.\n",
    "\n",
    "    Args:\n",
    "        query_text: text query, encoded with `encode_search_query`.\n",
    "        query_photo_id: ID of the photo to blend in; must be present in `photo_ids`.\n",
    "        photo_weight: factor applied to the photo features before they are\n",
    "            added to the text features.\n",
    "    \"\"\"\n",
    "    # Encode the search query once; the embedding is reused for both searches\n",
    "    text_features = encode_search_query(query_text)\n",
    "\n",
    "    # Find the feature vector for the specified photo ID\n",
    "    query_photo_index = photo_ids.index(query_photo_id)\n",
    "    query_photo_features = photo_features[query_photo_index]\n",
    "\n",
    "    # Combine the text and photo queries and normalize again. The sum is a new\n",
    "    # array, so the in-place division below leaves text_features untouched.\n",
    "    search_features = text_features + query_photo_features * photo_weight\n",
    "    search_features = torch.from_numpy(search_features)\n",
    "    search_features /= search_features.norm(dim=-1, keepdim=True)\n",
    "    search_features = search_features.numpy()\n",
    "\n",
    "    # Find the best match for the text-only query and for the combined query\n",
    "    text_only_photo_ids = find_best_matches(text_features, photo_features, photo_ids, 1)\n",
    "    best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 1)\n",
    "\n",
    "    # Display the results\n",
    "    # NOTE(review): 'Test' in the label below looks like a typo for 'Text'; the\n",
    "    # string is left as-is so re-runs match the recorded outputs.\n",
    "    print(\"Test search result\")\n",
    "    for photo_id in text_only_photo_ids:\n",
    "        display_photo(photo_id)\n",
    "\n",
    "    print(\"Photo query\")\n",
    "    display(Image(url=f\"https://unsplash.com/photos/{query_photo_id}/download?w=320\"))\n",
    "\n",
    "    print(\"Result for text query + photo query\")\n",
    "    display_photo(best_photo_ids[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:32:14.318745Z",
     "start_time": "2021-04-10T01:32:01.672787Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test search result\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/TZxIwXnwGC0/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/TZxIwXnwGC0\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Photo query\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/MaerUPAjPbs/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result for text query + photo query\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/GSXMuGwY2l4/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/GSXMuGwY2l4\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Query 'Sydney Opera House' blended with a photo, using the default photo weight\n",
    "search_by_text_and_photo(query_text=\"悉尼歌剧院\", query_photo_id=\"MaerUPAjPbs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:32:27.645199Z",
     "start_time": "2021-04-10T01:32:14.320490Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test search result\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/TZxIwXnwGC0/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/TZxIwXnwGC0\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Photo query\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/1pNBJ2zUfn4/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result for text query + photo query\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/TsRFPlkUkr4/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/TsRFPlkUkr4\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Same text query with a different photo and a photo weight of 0.3\n",
    "search_by_text_and_photo(\"悉尼歌剧院\", \"1pNBJ2zUfn4\", photo_weight=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-04-10T01:32:41.345246Z",
     "start_time": "2021-04-10T01:32:27.647216Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test search result\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/TZxIwXnwGC0/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/TZxIwXnwGC0\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Photo query\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/jnBDclcdZ7A/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result for text query + photo query\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://unsplash.com/photos/IFFrV0_bR6U/download?w=320\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Photo on <a target=\"_blank\" href=\"https://unsplash.com/photos/IFFrV0_bR6U\">Unsplash</a> "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Same text query with a different photo and a photo weight of 0.5\n",
    "search_by_text_and_photo(\"悉尼歌剧院\", \"jnBDclcdZ7A\", photo_weight=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch171]",
   "language": "python",
   "name": "conda-env-torch171-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
